package com.stlm2.util;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

/**
 * Code generator for the MyBatis / Spring persistence layer.
 * <p>
 * Constructing an instance immediately runs the generator for one package:
 * <ol>
 * <li>every file under {@code <package>/entity/base} is mapped to an entity class
 * {@code <package>.entity.<Name>} (the first four characters of the file name are dropped,
 * presumably a {@code Base} prefix), and a {@code mapper/<Name>.mapper.xml} is written from the
 * getters declared on that entity's direct superclass;</li>
 * <li>for every file in {@code <package>/mapper}, a DAO interface, a DAO implementation and a
 * service class are written.</li>
 * </ol>
 * Existing files are never overwritten. All generated files are written as UTF-8.
 */
public class CodeGenerator {

	// Project (module) directory name relative to the working directory; check before running.
	private final static String PROJECT_NAME = "core";

	private final static String SRC_MENU = "src";

	private final static String MAVEN_SRC_MENU = "src" + File.separator + "main" + File.separator + "java";

	/** Shared empty result for directory listings that yield nothing. */
	private final static File[] NO_FILES = new File[0];

	public static void main(String[] args) {
		String targetPackage = "com" + File.separator + "stlm2" + File.separator + "dems";

		new CodeGenerator(targetPackage, MAVEN_SRC_MENU);
	}

	private final static String BASE_DAO_NAME = "com.stlm2.core.base.BaseDao";

	private final static String BASE_DAO_IMPL_NAME = "com.stlm2.core.base.impl.BaseDaoImpl";

	private final static String BASE_SERVICE_IMPL_NAME = "com.stlm2.core.base.impl.BaseServiceImpl";

	/** Absolute directory of the target package: project path + source root + package path. */
	private final String basePackage;

	/** Target package as a path using {@link File#separator}. */
	private final String packageName;

	/**
	 * Runs the generator for one package.
	 *
	 * @param packageName package path using {@link File#separator}, e.g. {@code com/stlm2/dems}
	 * @param srcMenu     source root relative to the project directory, e.g. {@code src/main/java}
	 */
	public CodeGenerator(String packageName, String srcMenu) {
		this.packageName = packageName;
		this.basePackage = this.getProjectPath() + File.separator + srcMenu
				+ File.separator + packageName;
		System.out.println(this.packageName);
		System.out.println(this.basePackage);
		generatorMapper(packageName);
		System.out.println(packageName + "下mapper文件生成完毕");
		generator();
		System.out.println(packageName + "下dao、service层文件生成完毕");
	}

	/**
	 * Walks the {@code mapper} directory and generates DAO, DAO-impl and service sources for each
	 * file, using the file name up to its first '.' as the entity class name.
	 */
	private void generator() {
		for (File file : listFiles(new File(basePackage + File.separator + "mapper"))) {
			String className = stripExtension(file);
			// Directories, dot-less names and hidden files (".DS_Store") have no usable class name.
			if (className == null)
				continue;

			generatorDao(className);
			generatorDaoImpl(className);
			generatorService(className);
		}
	}

	/**
	 * Generates a mapper XML for every entity whose source file lives in {@code entity/base}.
	 * Failures for one entity are reported and do not stop the others.
	 *
	 * @param packageName package path using {@link File#separator}
	 */
	private void generatorMapper(String packageName) {
		String basePath = basePackage + File.separator + "entity" + File.separator + "base";
		String entityPackage = packageName.replace(File.separator, ".") + ".entity.";
		for (File file : listFiles(new File(basePath))) {
			String fileStem = stripExtension(file);
			// The first four characters are dropped to obtain the entity name (presumably "Base").
			if (fileStem == null || fileStem.length() <= 4)
				continue;

			try {
				Class<?> clazz = Class.forName(entityPackage + fileStem.substring(4));

				File mapperFile = newTargetFile(basePackage + File.separator + "mapper",
						clazz.getSimpleName() + ".mapper.xml");
				// An existing mapper is never overwritten.
				if (mapperFile == null)
					continue;

				String content = buildMapperContent(clazz);
				if (content == null) {
					System.out.println("skip " + clazz.getName() + ": no getter to map");
					continue;
				}
				writeLines(mapperFile, content);
			} catch (Exception e) {
				// Best effort: report and carry on with the remaining entities.
				e.printStackTrace();
			}
		}
	}

	/**
	 * Builds the mapper XML for {@code clazz} from the getters declared on its direct superclass.
	 *
	 * @param clazz the entity class; its simple name becomes the mapper's parameter/result type
	 * @return the XML text, or {@code null} when there is no superclass or no mappable getter
	 */
	private String buildMapperContent(Class<?> clazz) {
		Class<?> base = clazz.getSuperclass();
		if (base == null)
			return null;

		String resultMapTemplate = "<result property=\"{{1}}\" column=\"{{2}}\" jdbcType=\"{{3}}\"/>";
		List<String> resultLines = new ArrayList<String>();
		List<String> columns = new ArrayList<String>();
		List<String> values = new ArrayList<String>();
		List<String> assignments = new ArrayList<String>();

		for (Method method : base.getDeclaredMethods()) {
			if (!isMappedGetter(method))
				continue;
			Class<?> returnType = method.getReturnType();
			String property = method.getName().substring(3);
			String column = getColumn(property);
			property = toLowCase(property);
			// Enum properties are mapped under "<property>Value" (presumably an int accessor on the
			// entity; getters ending in "Value" are themselves excluded by isMappedGetter).
			if (isEnum(returnType))
				property = property + "Value";
			String jdbcType = getJdbcType(returnType);

			resultLines.add(resultMapTemplate.replace("{{1}}", property)
					.replace("{{2}}", column).replace("{{3}}", jdbcType));
			columns.add("\t\t" + column);
			values.add("\t\t#{" + property + "}");
			assignments.add(column + "=#{" + property + "}");
		}
		if (resultLines.isEmpty())
			return null;

		// Separators reproduce the indentation expected by the template placeholders.
		String resultMap = join(resultLines, "\n\t\t");
		String add = "(\n" + join(columns, ",\n") + "\n"
				+ "\t\t) values (\n" + join(values, ",\n") + " )";
		String update = join(assignments, ",\n\t\t");
		String selectField = "*";

		return getMapperTemplate().replace("{1}", clazz.getName()).replace("{2}", clazz.getSimpleName())
				.replace("{3}", resultMap).replace("{5}", selectField).replace("{6}", add)
				.replace("{7}", update);
	}

	/**
	 * Whether {@code method} becomes a mapped column: a non-static, non-synthetic, parameterless,
	 * non-void {@code getXxx} method other than {@code getId} and any getter ending in {@code Value}.
	 */
	private static boolean isMappedGetter(Method method) {
		String name = method.getName();
		return name.length() > 3
				&& name.startsWith("get")
				&& !name.equals("getId")
				&& !name.endsWith("Value")
				&& method.getParameterTypes().length == 0
				&& method.getReturnType() != void.class
				&& !Modifier.isStatic(method.getModifiers())
				&& !method.isSynthetic();
	}

	/**
	 * Converts a camel-case property name to an upper-case, underscore-separated column name,
	 * e.g. {@code UserName -> USER_NAME}. An underscore is inserted before every non-leading
	 * character that is unchanged by upper-casing and is not a digit (so non-letters such as
	 * {@code _} also get one). Conversion is locale-independent.
	 *
	 * @param name property name; {@code null} or empty yields {@code ""}
	 * @return the column name; {@code VALID} is special-cased to {@code IS_VALID}
	 */
	private static String getColumn(String name) {
		StringBuilder result = new StringBuilder();
		if (name != null) {
			for (int i = 0; i < name.length(); i++) {
				char c = name.charAt(i);
				char upper = Character.toUpperCase(c);
				if (i != 0 && c == upper && !Character.isDigit(c))
					result.append('_');
				result.append(upper);
			}
		}
		String column = result.toString();
		if (column.equals("VALID"))
			column = "IS_VALID";

		return column;
	}

	/**
	 * Whether {@code clazz} is an enum type.
	 *
	 * @param clazz the type to test, may be {@code null}
	 * @return {@code true} only for a non-null enum class
	 */
	private boolean isEnum(Class<?> clazz) {
		return clazz != null && clazz.isEnum();
	}

	/**
	 * Maps a getter return type to a MyBatis JDBC type name.
	 * Enums map to {@code TINYINT}. This is checked first, so an enum whose name merely contains
	 * e.g. "Date" is not mistaken for a date. Primitive {@code long/int/double/boolean} map like
	 * their wrappers. Other types are matched by substring of their {@code toString()}.
	 *
	 * @param type the return type (a {@link Class})
	 * @return the JDBC type name, or {@code ""} when unknown
	 */
	private String getJdbcType(Object type) {
		if (type == null)
			return "";
		if (type instanceof Class) {
			Class<?> clazz = (Class<?>) type;
			if (isEnum(clazz))
				return "TINYINT";
			if (clazz == long.class)
				return "BIGINT";
			if (clazz == int.class)
				return "INTEGER";
			if (clazz == double.class)
				return "DOUBLE";
			if (clazz == boolean.class)
				return "TINYINT";
		}
		String name = type.toString();
		if (name.contains("Long"))
			return "BIGINT";
		else if (name.contains("Integer"))
			return "INTEGER";
		else if (name.contains("String"))
			return "VARCHAR";
		else if (name.contains("Boolean"))
			return "TINYINT";
		else if (name.contains("Double"))
			return "DOUBLE";
		else if (name.contains("Date"))
			return "TIMESTAMP";
		else if (name.contains("BigDecimal"))
			return "DECIMAL";
		return "";
	}

	/**
	 * Returns the mapper XML template. Placeholders: {1} namespace (entity class name),
	 * {2} entity simple name, {3} result-map entries, {5} select fields, {6} insert body,
	 * {7} update assignments. The {@code table} SQL fragment is left empty for manual completion.
	 *
	 * @return the template text
	 */
	private String getMapperTemplate(){
		return "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
				"<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\"\n" +
				"\t\t\"classpath://mybatis-3-mapper.dtd\">\n" +
				"<mapper namespace=\"{1}\">\n" +
				"\t<resultMap type=\"{2}\" id=\"resultMap\">\n" +
				"\t\t<id property=\"id\" column=\"ID\"/> \n" +
				"\t\t{3}\n" +
				"\t</resultMap>\n" +
				"\n" +
				"\t<sql id=\"table\"></sql>\n" +
				"\t\n" +
				"\t<sql id=\"selectField\">\n" +
				"\t\t{5}\n" +
				"\t</sql>\n" +
				"\t\n" +
				"\t<insert id=\"add\" parameterType=\"{2}\">\n" +
				"\t\t<selectKey resultType=\"Integer\" keyProperty=\"id\" order=\"AFTER\">\n" +
				"\t\t\tSELECT\n" +
				"\t\t\tLAST_INSERT_ID() AS ID\n" +
				"\t\t</selectKey>\n" +
				"\t\tinsert into <include refid=\"table\" />\n" +
				"\t\t{6}\n" +
				"\t</insert>\n" +
				"\t\n" +
				"\t<delete id=\"deleteById\" parameterType=\"Integer\">\n" +
				"\t\tdelete from <include refid=\"table\" /> where ID=#{id}\n" +
				"\t</delete>\n" +
				"\t\n" +
				"\t<update id=\"update\" parameterType=\"{2}\">\n" +
				"\t\tupdate <include refid=\"table\" />\n" +
				"\t\t<set>\n" +
				"\t\t{7}\n" +
				"\t\t</set>\n" +
				"\t\twhere ID=#{id}\n" +
				"\t</update>\n" +
				"\t\n" +
				"\t<select id=\"queryAll\" resultMap=\"resultMap\">\n" +
				"\t\tselect <include refid=\"selectField\" /> from <include refid=\"table\" />\n" +
				"\t</select>\n" +
				"\t\n" +
				"\t<select id=\"query\" resultMap=\"resultMap\" parameterType=\"java.util.HashMap\">\n" +
				"\t\tselect\n" +
				"\t\t<include refid=\"selectField\" />\n" +
				"\t\tfrom <include refid=\"table\" />\n" +
				"\t\t<where>\n" +
				"\t\t\t1=1\n" +
				"\t\t\t\n" +
				"\t\t</where>\n" +
				"\t\torder by ID desc\n" +
				"\t</select>\n" +
				"\t\n" +
				"\t<!-- 根据ID查询 -->\n" +
				"\t<select id=\"getById\" parameterType=\"Integer\" resultMap=\"resultMap\">\n" +
				"\t\tselect <include refid=\"selectField\" /> from <include refid=\"table\" /> where ID=#{id}\n" +
				"\t</select>\n" +
				"</mapper>";
	}

	/**
	 * Generates {@code dao/<className>Dao.java}, an interface extending {@code BaseDao<className>}.
	 * Does nothing if the file already exists.
	 */
	private void generatorDao(String className) {
		String daoName = className + "Dao";
		File daoFile = newTargetFile(basePackage + File.separator + "dao", daoName + ".java");
		if (daoFile == null)
			return;

		try {
			writeLines(daoFile,
					"package " + qualifiedName("dao") + ";",
					"",
					"import " + qualifiedName("entity", className) + ";",
					"import " + BASE_DAO_NAME + ";",
					"",
					"public interface " + daoName + " extends BaseDao<" + className + "> {",
					"",
					"}");
		} catch (IOException e) {
			// Best effort: report and continue with the other generated files.
			e.printStackTrace();
		}
	}

	/**
	 * Generates {@code dao/impl/<className>DaoImpl.java}, a {@code @Repository} class extending
	 * {@code BaseDaoImpl<className>} and implementing the DAO interface.
	 * Does nothing if the file already exists.
	 */
	private void generatorDaoImpl(String className) {
		String daoName = className + "Dao";
		String daoImplName = daoName + "Impl";
		File daoImplFile = newTargetFile(basePackage + File.separator + "dao" + File.separator + "impl",
				daoImplName + ".java");
		if (daoImplFile == null)
			return;

		try {
			writeLines(daoImplFile,
					"package " + qualifiedName("dao", "impl") + ";",
					"",
					"import org.springframework.stereotype.Repository;",
					"",
					"import " + qualifiedName("entity", className) + ";",
					"import " + qualifiedName("dao", daoName) + ";",
					"import " + BASE_DAO_IMPL_NAME + ";",
					"",
					"@Repository(\"" + toLowCase(daoName) + "\")",
					"public class " + daoImplName + " extends BaseDaoImpl<"
							+ className + "> implements " + daoName + " {",
					"",
					"}");
		} catch (IOException e) {
			// Best effort: report and continue with the other generated files.
			e.printStackTrace();
		}
	}

	/**
	 * Generates {@code service/<className>Service.java}, a transactional {@code @Service} class
	 * extending {@code BaseServiceImpl<className>} that injects the DAO and returns it from
	 * {@code getDao()}. Does nothing if the file already exists.
	 */
	private void generatorService(String className) {
		String daoName = className + "Dao";
		String daoField = toLowCase(daoName);
		String serviceName = className + "Service";
		File serviceFile = newTargetFile(basePackage + File.separator + "service", serviceName + ".java");
		if (serviceFile == null)
			return;

		try {
			writeLines(serviceFile,
					"package " + qualifiedName("service") + ";",
					"",
					"import org.springframework.stereotype.Service;",
					"import org.springframework.transaction.annotation.Propagation;",
					"import org.springframework.transaction.annotation.Transactional;",
					"import javax.annotation.Resource;",
					"",
					"import " + qualifiedName("entity", className) + ";",
					"import " + qualifiedName("dao", daoName) + ";",
					"import " + BASE_DAO_NAME + ";",
					"import " + BASE_SERVICE_IMPL_NAME + ";",
					"",
					"@Service(\"" + toLowCase(serviceName) + "\")",
					"@Transactional(rollbackForClassName={\"java.lang.Exception\"}, propagation = Propagation.REQUIRED)",
					"public class " + serviceName + " extends BaseServiceImpl<" + className + "> {",
					"",
					"\t@Resource",
					"\tprivate " + daoName + " " + daoField + ";",
					"",
					"\t@Override",
					"\tprotected BaseDao <" + className + "> getDao() {",
					"\t\treturn " + daoField + ";",
					"\t}",
					"}");
		} catch (IOException e) {
			// Best effort: report and continue with the other generated files.
			e.printStackTrace();
		}
	}

	/**
	 * Returns the project directory on disk: the canonical working directory plus
	 * {@link #PROJECT_NAME}.
	 *
	 * @return the project path, or {@code null} if the working directory cannot be resolved
	 */
	public String getProjectPath() {
		try {
			File workingDir = new File(""); // empty path resolves to the working directory
			return workingDir.getCanonicalPath() + File.separator + PROJECT_NAME;
		} catch (IOException e) {
			return null;
		}
	}

	/**
	 * Lower-cases the first character (locale-independent).
	 *
	 * @param content text to convert; {@code null} or empty is returned unchanged
	 * @return the converted text
	 */
	private String toLowCase(String content) {
		if (content == null || content.length() == 0)
			return content;
		return Character.toLowerCase(content.charAt(0)) + content.substring(1);
	}

	public String getBasePackage() {
		return basePackage;
	}

	/**
	 * Converts {@link #packageName} plus the given sub-path segments into a dotted Java name.
	 */
	private String qualifiedName(String... segments) {
		StringBuilder path = new StringBuilder(packageName);
		for (String segment : segments)
			path.append(File.separator).append(segment);
		return path.toString().replace(File.separator, ".");
	}

	/**
	 * Lists a directory's entries.
	 *
	 * @return the entries, or an empty array if {@code dir} is not a directory or cannot be read
	 */
	private static File[] listFiles(File dir) {
		if (!dir.isDirectory())
			return NO_FILES;
		File[] files = dir.listFiles();
		return files == null ? NO_FILES : files;
	}

	/**
	 * Returns a regular file's name up to its first '.'.
	 *
	 * @return the stem, or {@code null} for directories and names without a non-empty stem
	 *         (no dot, or a leading dot as in {@code .DS_Store})
	 */
	private static String stripExtension(File file) {
		if (!file.isFile())
			return null;
		String name = file.getName();
		int dot = name.indexOf('.');
		return dot > 0 ? name.substring(0, dot) : null;
	}

	/**
	 * Creates {@code dirPath} (with parents) if missing and returns the file to generate in it.
	 *
	 * @return the target file, or {@code null} if it already exists (never overwrite)
	 */
	private static File newTargetFile(String dirPath, String fileName) {
		File dir = new File(dirPath);
		if (!dir.exists())
			dir.mkdirs();
		File target = new File(dir, fileName);
		return target.exists() ? null : target;
	}

	/**
	 * Writes {@code lines} to {@code file} as UTF-8, separated by the platform line separator,
	 * with no trailing separator. The writer is always closed; if writing fails the partial file
	 * is deleted so a later run does not mistake it for an already generated file.
	 *
	 * @throws IOException if the file cannot be written
	 */
	private static void writeLines(File file, String... lines) throws IOException {
		BufferedWriter out = new BufferedWriter(
				new OutputStreamWriter(new FileOutputStream(file), "UTF-8"));
		boolean written = false;
		try {
			for (int i = 0; i < lines.length; i++) {
				if (i > 0)
					out.newLine();
				out.write(lines[i]);
			}
			out.flush();
			written = true;
		} finally {
			out.close();
			if (!written)
				file.delete();
		}
	}

	/** Joins {@code parts} with {@code separator} between consecutive elements. */
	private static String join(List<String> parts, String separator) {
		StringBuilder joined = new StringBuilder();
		for (int i = 0; i < parts.size(); i++) {
			if (i > 0)
				joined.append(separator);
			joined.append(parts.get(i));
		}
		return joined.toString();
	}
}
